#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time    : 2018/6/21 10:15
# @Author  : Deyu Tian
# @Site    : 
# @File    : fireriskcal.py
# @Software: PyCharm Community Edition
import gdal, osr
import numpy as np

# Per-layer lookup tables: class id -> amount added to the FFSI of a pixel
# carrying that class. Class ids that are not listed contribute nothing.
_FFSI_CLASS_WEIGHTS = {
    "veg": {1: 0.290, 2: 5},
    "ndvi": {2: 3.125, 3: 1.25, 4: 4.4, 5: 0.431},
    "slope": {1: 1.305, 2: 0.217},
    "aspect": {1: 1.447, 2: 0.568, 3: 1.351, 4: 0.493},
}


def calcFFSI(veg, ndvi, slope, aspect, imggt,
             out_path="D:\\PROJECT_2018\\firerisk\\result\\ffsi.tif"):
    """
    calc FFSI (presumably a forest-fire susceptibility index; the acronym
    is not expanded in this file) by summing per-class weights of four
    classified layers, then write the result as a GeoTIFF.

    The output grid has the shape of ``veg``. The other layers may be larger;
    only their top-left ``veg.shape`` window is used.

    :param veg: 2-D array of vegetation class ids (defines the output shape)
    :param ndvi: 2-D array of NDVI class ids
    :param slope: 2-D array of slope class ids
    :param aspect: 2-D array of aspect class ids
    :param imggt: geotransform handed unchanged to array2rasterUTM
    :param out_path: file the result is written to; pass None to skip
                     writing and only get the array back
    :return: float32 array with the summed weights, same shape as ``veg``
    :raises IndexError: if a layer has fewer rows or columns than ``veg``
    """
    x = veg.shape[0]
    y = veg.shape[1]
    print(x, y)
    ffsi = np.zeros_like(veg, dtype="float32")

    layers = (("veg", veg), ("ndvi", ndvi), ("slope", slope), ("aspect", aspect))
    for layer_name, layer in layers:
        if layer.shape[0] < x or layer.shape[1] < y:
            raise IndexError(
                "{} layer has shape {} but at least ({}, {}) is required".format(
                    layer_name, layer.shape, x, y))
        # Restrict every layer to the extent of veg.
        window = layer[:x, :y]
        # A pixel holds exactly one class id per layer, so the masks of one
        # layer never overlap and can be applied in any order.
        for class_id, weight in _FFSI_CLASS_WEIGHTS[layer_name].items():
            ffsi[window == class_id] += weight

    if out_path is not None:
        array2rasterUTM(out_path, imggt, ffsi)
    return ffsi



def img2array(img):
    """
    read a raster file into an array by gdal
    :param img: path of the raster file (e.g. a geotiff)
    :return: whatever the dataset's ReadAsArray() returns for the file
    :raises RuntimeError: if gdal cannot open the file
    """
    img_data = gdal.Open(img)
    if img_data is None:
        # Unless gdal.UseExceptions() is active, gdal.Open signals failure by
        # returning None; fail here with the path instead of an opaque
        # AttributeError on the next line.
        raise RuntimeError("GDAL could not open raster: {}".format(img))
    return img_data.ReadAsArray()

def readgeotransform(img):
    """
    read the geo transform of a raster, for use when writing arrays back to tiff
    :param img: path of the raster file (e.g. a geotiff)
    :return: the value of the dataset's GetGeoTransform()
    :raises RuntimeError: if gdal cannot open the file
    """
    img_data = gdal.Open(img)
    if img_data is None:
        # Unless gdal.UseExceptions() is active, gdal.Open signals failure by
        # returning None; report the offending path explicitly.
        raise RuntimeError("GDAL could not open raster: {}".format(img))
    return img_data.GetGeoTransform()

def array2rasterUTM(newRasterfn, panTransform, array, epsg=32651):
    """
    write a 2-D array to a single-band Float32 GeoTIFF
    :param newRasterfn: path of the GeoTIFF to create
    :param panTransform: geotransform; its first six items are applied to the output
    :param array: 2-D array (rows, cols) written to band 1
    :param epsg: EPSG code of the output projection, default 32651 (utm zone 51n)
    :raises RuntimeError: if gdal cannot create the output file
    """
    rows = array.shape[0]
    cols = array.shape[1]

    driver = gdal.GetDriverByName('GTiff')
    outRaster = driver.Create(newRasterfn, cols, rows, 1, gdal.GDT_Float32)
    if outRaster is None:
        # Unless gdal.UseExceptions() is active, Create signals failure by
        # returning None (e.g. missing directory, no write permission).
        raise RuntimeError("GDAL could not create raster: {}".format(newRasterfn))

    outRaster.SetGeoTransform(tuple(panTransform[k] for k in range(6)))
    outband = outRaster.GetRasterBand(1)
    outband.WriteArray(array)

    outRasterSRS = osr.SpatialReference()
    outRasterSRS.ImportFromEPSG(epsg)
    outRaster.SetProjection(outRasterSRS.ExportToWkt())

    outband.FlushCache()
    outRaster.FlushCache()
    # Drop the band before its dataset so the file is closed deterministically
    # instead of relying on garbage-collection order.
    outband = None
    outRaster = None

if __name__ == '__main__':
    # Folder with the classified input rasters.
    data_path = "D:\\PROJECT_2018\\firerisk\\features"
    veg_path = "{}\\veg_resm.tif".format(data_path)
    ndvi_path = "{}\\ndvi_reclass.tif".format(data_path)
    # NOTE(review): the slope layer is read from an "ndvi_..." file while the
    # aspect layer uses "aspect_..."; this looks like a copy-paste slip --
    # confirm the intended file name before trusting the slope term.
    slope_path = "{}\\ndvi_reclass_reproj1.tif".format(data_path)
    aspect_path = "{}\\aspect_reclass_reproj1.tif".format(data_path)

    # veg loses its first row and first column; slope and aspect lose only
    # their first row (presumably to line the grids up with ndvi -- TODO confirm).
    veg = img2array(veg_path)[1:, 1:]
    ndvi = img2array(ndvi_path)
    slope = img2array(slope_path)[1:]
    aspect = img2array(aspect_path)[1:]

    # The ndvi raster supplies the geotransform of the output.
    imggt = readgeotransform(ndvi_path)
    calcFFSI(veg, ndvi, slope, aspect, imggt)